#!/usr/bin/env python3
"""
Phase 2: IV Estimation with Lagged Fertility Instruments

Implements two-stage least squares (2SLS) using lagged demographic
variables as instruments for current Z₁, Z₂, Z₃.

Identification logic:
- Births 20-30 years ago determine current age structure mechanically
- These lagged demographics are predetermined (exogenous to current CA shocks)
- Exclusion restriction: lagged fertility affects CA only through current demographics

Models estimated:
A. OLS baseline (full sample) — benchmark
B. OLS baseline (ex-CCA) — benchmark
C. 2SLS with 20yr lagged Z as instruments (full sample)
D. 2SLS with 25yr lagged Z as instruments (full sample)
E. 2SLS with 30yr lagged Z as instruments (full sample)
F. 2SLS with 25yr lagged Z (ex-CCA) — KEY TEST
G. 2SLS with 25yr lagged age shares (overidentified, full sample)
H. 2SLS with Bartik instruments (full sample)
I. 2SLS with 25yr lagged Z (CCA only)

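Diagnostics:
- First-stage partial F-statistics and partial R² of the excluded instruments,
  computed separately for each endogenous variable (not the joint
  Cragg-Donald / Kleibergen-Paap statistics)
- Weak-instrument flag using the rule-of-thumb F < 10 (Stock-Yogo critical
  values are not applied)
- Overidentification test for overidentified models (n·R² J statistic from
  regressing the 2SLS residuals on all instruments)
- Hausman test (OLS vs IV, diagonal approximation ignoring covariances)
- Anderson-Rubin F-test of H0: all endogenous coefficients = 0 (robust to
  weak instruments; confidence sets are not computed)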

Output:
- iv_results.csv: coefficient table across all models
- first_stage_diagnostics.csv: F-stats, partial R², instrument relevance
- iv_comparison.csv: OLS vs IV comparison with Hausman test
"""

import sys
import pandas as pd
import numpy as np
from pathlib import Path
from scipy import stats as scipy_stats
import statsmodels.api as sm

# Project paths
# NOTE(review): absolute, machine-specific root (looks like a WSL mount of C:);
# adjust when running on another machine.
PROJECT_DIR = Path("/mnt/c/demographics_capital_flows")
# Sibling project whose ``src`` package provides PanelGLS (imported below).
MULTILATERAL_DIR = PROJECT_DIR / "multilateral"
CAUSAL_DIR = PROJECT_DIR / "causal_identification"
# load_panel() reads causal_panel.csv from this directory.
PROCESSED_DIR = CAUSAL_DIR / "data" / "processed"
# The __main__ section writes its CSV tables here.
OUTPUT_DIR = CAUSAL_DIR / "output" / "tables"
# Side effect at import time: the output directory is created if missing.
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

# Put the multilateral project first on sys.path so ``src.model`` resolves to
# it; PanelGLS is the AR(1)-corrected panel estimator used in both 2SLS stages.
sys.path.insert(0, str(MULTILATERAL_DIR))
from src.model import PanelGLS

# ISO3 codes of the country group labelled "CCA" in this project. The __main__
# section uses them to build the ex-CCA (excluded) and CCA-only subsamples.
CCA_COUNTRIES = [
    'ARM', 'AZE', 'BLR', 'GEO', 'KAZ', 'KGZ', 'MDA',
    'MNG', 'RUS', 'TJK', 'TKM', 'UKR', 'UZB'
]


# =====================================================================
# 2SLS Estimator
# =====================================================================

class PanelIV2SLS:
    """
    Two-stage least squares estimator with PanelGLS AR(1) correction.

    Stage 1: Regress each endogenous variable on instruments + exogenous controls
    Stage 2: Regress outcome on fitted endogenous + exogenous controls

    Both stages use PanelGLS with Cochrane-Orcutt AR(1) correction.

    Standard errors are approximate. The residual variance comes from the
    structural residuals ``y - [1, X_endog, X_exog] @ beta``. It is combined
    with the unweighted bread ``(X_hat' X_hat)^{-1}``, so the AR(1) transform
    used for the point estimates is not reflected in the SEs.
    """

    def __init__(self):
        # Second-stage slope estimates (the constant is stored separately).
        self.beta = None
        self.se = None
        self.tvalues = None
        self.pvalues = None
        self.constant = None
        self.se_constant = None
        # Fit statistics computed with the actual (not instrumented) regressors.
        self.r_squared = None
        self.rho = None
        self.resid = None
        self.fitted = None
        self.n_obs = None
        self.n_countries = None
        # One diagnostics dict per endogenous regressor.
        self.first_stage_results = []
        self.feature_names = None
        # Overidentification J statistic; stays None when exactly identified.
        self.overid_stat = None
        self.overid_pval = None

    def fit(self, y, X_endog, X_exog, Z_instruments,
            entity_ids, time_ids, endog_names=None, exog_names=None,
            instrument_names=None):
        """
        Estimate 2SLS.

        Parameters
        ----------
        y : array, dependent variable (CA/GDP)
        X_endog : 2-D array, endogenous regressors (Z_1, Z_2, Z_3)
        X_exog : 2-D array, exogenous controls (fiscal, NFA, etc.)
        Z_instruments : 2-D array, excluded instruments (lagged Z's)
        entity_ids, time_ids : arrays, panel identifiers
        endog_names, exog_names : optional lists of column names. When both
            are given they are stored as ``feature_names`` for ``summary``.
        instrument_names : optional list of excluded-instrument names, used to
            label the per-instrument first-stage coefficients.

        Returns
        -------
        self

        Raises
        ------
        ValueError
            If the model is under-identified (fewer excluded instruments than
            endogenous regressors).
        """
        y = np.asarray(y, dtype=float)
        X_endog = np.asarray(X_endog, dtype=float)
        X_exog = np.asarray(X_exog, dtype=float)
        Z_instruments = np.asarray(Z_instruments, dtype=float)
        entity_ids = np.asarray(entity_ids)
        time_ids = np.asarray(time_ids)

        n_endog = X_endog.shape[1]
        n_inst = Z_instruments.shape[1]

        # Order condition: at least one excluded instrument per endogenous
        # regressor. Otherwise X_hat is collinear with X_exog, and with zero
        # instruments the partial F-test below would divide by q = 0.
        if n_inst < n_endog:
            raise ValueError(
                f"Model is under-identified: {n_inst} excluded instrument(s) "
                f"for {n_endog} endogenous regressor(s)"
            )

        if endog_names is not None and exog_names is not None:
            self.feature_names = list(endog_names) + list(exog_names)

        # Listwise deletion: keep only rows complete in every numeric input
        all_data = np.column_stack([y, X_endog, X_exog, Z_instruments])
        mask = ~np.any(np.isnan(all_data), axis=1)
        y = y[mask]
        X_endog = X_endog[mask]
        X_exog = X_exog[mask]
        Z_instruments = Z_instruments[mask]
        entity_ids = entity_ids[mask]
        time_ids = time_ids[mask]

        self.n_obs = len(y)
        self.n_countries = len(np.unique(entity_ids))

        # Full instrument matrix: excluded instruments first, then exogenous
        # controls (so the first n_inst first-stage coefficients are the
        # excluded instruments' coefficients).
        Z_full = np.column_stack([Z_instruments, X_exog])

        # ---- STAGE 1: Instrument each endogenous variable ----
        X_hat = np.zeros_like(X_endog)
        self.first_stage_results = []

        for j in range(n_endog):
            name = endog_names[j] if endog_names else f'endog_{j}'

            # Unrestricted first stage: endogenous variable on Z_full
            gls1 = PanelGLS()
            gls1.fit(X_endog[:, j], Z_full, entity_ids, time_ids)
            X_hat[:, j] = gls1.fitted

            # Restricted first stage (controls only), used for the partial
            # F-test of the excluded instruments.
            # NOTE(review): each fit estimates its own AR(1) rho, so the SSR
            # comparison is approximate. Confirm that PanelGLS.resid values
            # are on comparable scales.
            gls_restricted = PanelGLS()
            gls_restricted.fit(X_endog[:, j], X_exog, entity_ids, time_ids)

            ss_res_r = np.sum(gls_restricted.resid ** 2)
            ss_res_u = np.sum(gls1.resid ** 2)
            q = n_inst  # number of excluded instruments (>= 1 here)
            n = self.n_obs
            k = Z_full.shape[1] + 1  # +1 for constant

            f_stat = ((ss_res_r - ss_res_u) / q) / (ss_res_u / (n - k))
            # Survival function keeps precision for very small p-values
            f_pval = scipy_stats.f.sf(f_stat, q, n - k)

            # Partial R² of the excluded instruments: the share of the
            # restricted SSR they explain. This is not Shea's partial R².
            partial_r2 = 1 - ss_res_u / ss_res_r

            # Individual instrument coefficients from first stage
            inst_coefs = {}
            if instrument_names:
                for i_idx, i_name in enumerate(instrument_names):
                    inst_coefs[f'{i_name}_coef'] = gls1.beta[i_idx]
                    inst_coefs[f'{i_name}_se'] = gls1.se[i_idx]
                    inst_coefs[f'{i_name}_pval'] = gls1.pvalues[i_idx]

            self.first_stage_results.append({
                'endogenous_var': name,
                'f_stat': f_stat,
                'f_pval': f_pval,
                'partial_r2': partial_r2,
                'r_squared': gls1.r_squared,
                'rho': gls1.rho,
                'n_obs': gls1.n_obs,
                'n_instruments': n_inst,
                **inst_coefs
            })

            print(f"    First stage [{name}]: F={f_stat:.2f} (p={f_pval:.6f}), "
                  f"partial R²={partial_r2:.4f}, R²={gls1.r_squared:.4f}")

        # ---- STAGE 2: Regress y on X_hat + X_exog ----
        X_second = np.column_stack([X_hat, X_exog])
        gls2 = PanelGLS()
        gls2.fit(y, X_second, entity_ids, time_ids)

        # Structural residuals use the ACTUAL endogenous regressors, not
        # X_hat: e = y - [1, X_endog, X_exog] @ [constant, beta]
        X_actual = np.column_stack([X_endog, X_exog])
        X_actual_const = sm.add_constant(X_actual)
        beta_full = np.concatenate([[gls2.constant], gls2.beta])

        resid_actual = y - X_actual_const @ beta_full
        ss_res = np.sum(resid_actual ** 2)
        df_resid = self.n_obs - len(beta_full)
        sigma2 = ss_res / df_resid

        # Textbook homoskedastic 2SLS covariance V = sigma² (X_hat' X_hat)^{-1}
        # (no AR(1) weighting; see class docstring).
        X_hat_const = sm.add_constant(np.column_stack([X_hat, X_exog]))
        try:
            bread = np.linalg.inv(X_hat_const.T @ X_hat_const)
            V = sigma2 * bread
            se_corrected = np.sqrt(np.diag(V))
        except np.linalg.LinAlgError:
            # Fall back to the stage-2 GLS SEs. They are computed from X_hat
            # and are not the corrected 2SLS SEs.
            se_corrected = np.concatenate([[gls2.se_constant], gls2.se])

        self.beta = gls2.beta
        self.constant = gls2.constant
        self.se = se_corrected[1:]  # exclude constant
        self.se_constant = se_corrected[0]
        self.tvalues = self.beta / self.se
        self.pvalues = 2 * scipy_stats.t.sf(np.abs(self.tvalues), df_resid)
        self.resid = resid_actual
        self.fitted = X_actual_const @ beta_full
        self.rho = gls2.rho

        # R-squared (from actual residuals, not fitted)
        ss_tot = np.sum((y - np.mean(y)) ** 2)
        self.r_squared = 1 - ss_res / ss_tot

        # Overidentification test if overidentified: J = n * R² from
        # regressing the 2SLS residuals on all instruments (Sargan n·R² form),
        # chi-square with (n_inst - n_endog) degrees of freedom.
        self.overid_stat = None
        self.overid_pval = None
        if n_inst > n_endog:
            Z_full_const = sm.add_constant(Z_full)
            j_ols = sm.OLS(resid_actual, Z_full_const).fit()
            self.overid_stat = self.n_obs * j_ols.rsquared
            self.overid_pval = scipy_stats.chi2.sf(
                self.overid_stat, n_inst - n_endog
            )

        return self

    def summary(self, feature_names=None):
        """Print the coefficient table plus first-stage and overid diagnostics.

        Parameters
        ----------
        feature_names : list of str, optional
            Coefficient labels. They are stored on the instance. If omitted,
            the names set by ``fit`` are used, or else X1, X2, ...

        Raises
        ------
        RuntimeError
            If called before ``fit``.
        """
        if self.beta is None:
            raise RuntimeError("PanelIV2SLS.summary() called before fit()")
        if feature_names is not None:
            self.feature_names = feature_names
        names = self.feature_names or [f'X{i+1}' for i in range(len(self.beta))]

        print("=" * 70)
        print("2SLS IV Estimation with PanelGLS AR(1)")
        print("=" * 70)
        print(f"N obs: {self.n_obs:,}   N countries: {self.n_countries}")
        print(f"R²: {self.r_squared:.4f}   ρ (AR1): {self.rho:.4f}")
        print("-" * 70)
        print(f"{'Variable':<25} {'Coef':>10} {'Std Err':>10} {'t-stat':>10} {'p-val':>8}")
        print("-" * 70)

        for i, name in enumerate(names):
            sig = ''
            if self.pvalues[i] < 0.001:
                sig = '***'
            elif self.pvalues[i] < 0.01:
                sig = '**'
            elif self.pvalues[i] < 0.05:
                sig = '*'
            print(f"{name:<25} {self.beta[i]:>10.4f} {self.se[i]:>10.4f} "
                  f"{self.tvalues[i]:>10.2f} {self.pvalues[i]:>7.4f} {sig}")

        print("-" * 70)

        if self.first_stage_results:
            print("\nFirst-stage diagnostics:")
            for fs in self.first_stage_results:
                # Rule-of-thumb weak-instrument threshold F < 10
                weak = "WEAK" if fs['f_stat'] < 10 else "OK"
                print(f"  {fs['endogenous_var']}: F={fs['f_stat']:.2f} "
                      f"(p={fs['f_pval']:.6f}), partial R²={fs['partial_r2']:.4f} [{weak}]")

        if self.overid_stat is not None:
            print(f"\nOveridentification (Hansen J): χ²={self.overid_stat:.4f} "
                  f"(p={self.overid_pval:.4f})")


# =====================================================================
# Anderson-Rubin test (robust to weak instruments)
# =====================================================================

def anderson_rubin_test(y, X_endog, X_exog, Z_instruments,
                        entity_ids, time_ids, beta_null=None):
    """
    Anderson-Rubin test for joint significance of endogenous variables.

    This test is valid even with weak instruments. Under H0: beta_endog = beta_null,
    the AR statistic is F-distributed.

    If beta_null is None, tests H0: all endogenous coefficients = 0.

    The statistic is the F-test of the excluded instruments in a pooled OLS
    regression (with intercept) of ``y - X_endog @ beta_null`` on
    ``[Z_instruments, X_exog]``, compared with ``X_exog`` alone. No AR(1) or
    panel correction is applied. ``entity_ids`` and ``time_ids`` are accepted
    for interface symmetry with PanelIV2SLS.fit but are not used.

    Rows with NaN in any numeric input are dropped before estimation.

    Returns
    -------
    (f_stat, f_pval) : tuple of float
    """
    y = np.asarray(y, dtype=float)
    X_endog = np.asarray(X_endog, dtype=float)
    X_exog = np.asarray(X_exog, dtype=float)
    Z_instruments = np.asarray(Z_instruments, dtype=float)

    n_endog = X_endog.shape[1]
    if beta_null is None:
        beta_null = np.zeros(n_endog)
    beta_null = np.asarray(beta_null, dtype=float)

    # Listwise deletion. sm.OLS is called with its default missing-data
    # handling, which does not drop NaNs.
    complete = ~np.any(
        np.isnan(np.column_stack([y, X_endog, X_exog, Z_instruments])), axis=1
    )
    y = y[complete]
    X_endog = X_endog[complete]
    X_exog = X_exog[complete]
    Z_instruments = Z_instruments[complete]

    # Under H0, y_tilde depends only on the controls, so the excluded
    # instruments should add no explanatory power.
    y_tilde = y - X_endog @ beta_null

    # Unrestricted: y_tilde on [Z_instruments, X_exog] + constant
    Z_const = sm.add_constant(np.column_stack([Z_instruments, X_exog]))
    ss_u = sm.OLS(y_tilde, Z_const).fit().ssr

    # Restricted: y_tilde on X_exog + constant
    ss_r = sm.OLS(y_tilde, sm.add_constant(X_exog)).fit().ssr

    q = Z_instruments.shape[1]
    n = len(y_tilde)
    k = Z_const.shape[1]

    f_stat = ((ss_r - ss_u) / q) / (ss_u / (n - k))
    # Survival function keeps precision for very small p-values
    f_pval = scipy_stats.f.sf(f_stat, q, n - k)

    return f_stat, f_pval


# =====================================================================
# Main estimation
# =====================================================================

def load_panel(start_year=1992, end_year=2024):
    """Load the causal panel and restrict it to the estimation window.

    Parameters
    ----------
    start_year, end_year : int
        Inclusive bounds on the ``year`` column. The defaults (1992-2024) are
        the window used by every model in this script.

    Returns
    -------
    pandas.DataFrame
        A copy of the rows of ``causal_panel.csv`` with
        ``start_year <= year <= end_year``.
    """
    df = pd.read_csv(PROCESSED_DIR / "causal_panel.csv", low_memory=False)

    # Series.between is inclusive on both ends by default
    in_window = df['year'].between(start_year, end_year)
    return df.loc[in_window].copy()


def run_ols_benchmark(df, label, sample_desc):
    """Fit the non-instrumented PanelGLS benchmark of CA/GDP on Z_1..Z_3 + controls.

    Optional controls absent from ``df`` are skipped. Rows with NaN in the
    outcome or any regressor are dropped.

    Returns
    -------
    dict or None
        One result row (model metadata plus ``<var>_coef/_se/_pval`` for every
        regressor), or None when fewer than 50 complete observations remain.
    """
    candidate_controls = ['fiscal_bal_gdp', 'kaopen', 'nfa_gdp_lag',
                          'trade_openness', 'log_rel_opw']
    regressors = ['Z_1', 'Z_2', 'Z_3'] + [
        col for col in candidate_controls if col in df.columns
    ]
    sample = df.dropna(subset=['ca_gdp'] + regressors).copy()

    n_complete = len(sample)
    if n_complete < 50:
        print(f"  {label}: insufficient observations ({n_complete})")
        return None

    model = PanelGLS()
    model.fit(sample['ca_gdp'].values, sample[regressors].values,
              sample['iso3'].values, sample['year'].values)
    model.summary(feature_names=regressors)

    row = {
        'model': label,
        'sample': sample_desc,
        'method': 'OLS (PanelGLS)',
        'n_obs': model.n_obs,
        'n_countries': model.n_countries,
        'r_squared': model.r_squared,
        'rho': model.rho,
    }
    for idx, var in enumerate(regressors):
        row.update({
            f'{var}_coef': model.beta[idx],
            f'{var}_se': model.se[idx],
            f'{var}_pval': model.pvalues[idx],
        })

    return row


def run_iv_model(df, lag, label, sample_desc, instrument_type='lagged_Z'):
    """Run one 2SLS IV specification of CA/GDP on Z_1..Z_3 plus controls.

    Parameters
    ----------
    df : pandas.DataFrame
        Panel with ``ca_gdp``, ``Z_1``..``Z_3``, ``iso3``, ``year``, the
        instrument columns and any of the optional controls.
    lag : int or None
        Instrument lag in years. It builds the column names for
        ``'lagged_Z'`` (``Z_k_lag{lag}``) and ``'lagged_shares'``
        (``d_n_{g}_lag{lag}``; None falls back to 25). Ignored for
        ``'bartik'``.
    label, sample_desc : str
        Identifiers copied into the printed header and the result row.
    instrument_type : {'lagged_Z', 'lagged_shares', 'bartik'}
        Which set of excluded instruments to use.

    Returns
    -------
    tuple
        ``(result_dict, fitted PanelIV2SLS)`` on success. Returns
        ``(None, None)`` when instrument columns are missing, the model would
        be under-identified, or fewer than 50 complete observations remain.
        The result is always a pair, so callers can unpack it unconditionally.

    Raises
    ------
    ValueError
        If ``instrument_type`` is not one of the supported values.
    """
    demo_vars = ['Z_1', 'Z_2', 'Z_3']
    controls = ['fiscal_bal_gdp', 'kaopen', 'nfa_gdp_lag',
                'trade_openness', 'log_rel_opw']

    available_controls = [c for c in controls if c in df.columns]

    # Define instruments based on type
    if instrument_type == 'lagged_Z':
        inst_vars = [f'{z}_lag{lag}' for z in demo_vars]
    elif instrument_type == 'lagged_shares':
        share_lag = 25 if lag is None else lag
        # NOTE(review): if the d_n_* shares sum to a constant across the 17
        # groups, all 17 together with the intercept are collinear. Confirm
        # how they are constructed and drop one group if so.
        inst_vars = [f'd_n_{g}_lag{share_lag}' for g in range(1, 18)]
        inst_vars = [v for v in inst_vars if v in df.columns]
    elif instrument_type == 'bartik':
        inst_vars = ['Z_1_bartik', 'Z_2_bartik', 'Z_3_bartik']
    else:
        raise ValueError(f"Unknown instrument type: {instrument_type}")

    # Skip the model gracefully instead of a KeyError in dropna or an
    # under-identified fit.
    missing = [v for v in inst_vars if v not in df.columns]
    if missing:
        print(f"  {label}: instrument columns not found: {missing}")
        return None, None
    if len(inst_vars) < len(demo_vars):
        print(f"  {label}: under-identified ({len(inst_vars)} instruments "
              f"for {len(demo_vars)} endogenous regressors)")
        return None, None

    all_needed = ['ca_gdp'] + demo_vars + available_controls + inst_vars
    comp = df.dropna(subset=all_needed).copy()

    if len(comp) < 50:
        print(f"  {label}: insufficient observations ({len(comp)})")
        return None, None

    print(f"\n{'='*70}")
    print(f"MODEL {label}: {sample_desc}")
    print(f"  Instruments: {inst_vars}")
    print(f"  Observations: {len(comp)}, Countries: {comp['iso3'].nunique()}")
    print(f"{'='*70}")

    y = comp['ca_gdp'].values
    X_endog = comp[demo_vars].values
    X_exog = comp[available_controls].values
    Z_inst = comp[inst_vars].values

    iv = PanelIV2SLS()
    iv.fit(y, X_endog, X_exog, Z_inst,
           comp['iso3'].values, comp['year'].values,
           endog_names=demo_vars, exog_names=available_controls,
           instrument_names=inst_vars)

    all_vars = demo_vars + available_controls
    iv.summary(feature_names=all_vars)

    # Anderson-Rubin test (robust to weak instruments)
    ar_f, ar_p = anderson_rubin_test(
        y, X_endog, X_exog, Z_inst,
        comp['iso3'].values, comp['year'].values
    )
    print(f"\n  Anderson-Rubin test (H0: Z coefficients = 0): "
          f"F={ar_f:.4f}, p={ar_p:.6f}")

    # Collect results
    result = {
        'model': label,
        'sample': sample_desc,
        'method': f'2SLS ({instrument_type})',
        'instruments': ', '.join(inst_vars),
        'n_obs': iv.n_obs,
        'n_countries': iv.n_countries,
        'r_squared': iv.r_squared,
        'rho': iv.rho,
        'ar_f_stat': ar_f,
        'ar_pval': ar_p,
    }

    for i, var in enumerate(all_vars):
        result[f'{var}_coef'] = iv.beta[i]
        result[f'{var}_se'] = iv.se[i]
        result[f'{var}_pval'] = iv.pvalues[i]

    # First stage diagnostics
    for fs in iv.first_stage_results:
        var = fs['endogenous_var']
        result[f'fs_{var}_F'] = fs['f_stat']
        result[f'fs_{var}_pval'] = fs['f_pval']
        result[f'fs_{var}_partial_r2'] = fs['partial_r2']

    if iv.overid_stat is not None:
        result['hansen_j'] = iv.overid_stat
        result['hansen_j_pval'] = iv.overid_pval

    return result, iv


def hausman_test(ols_result, iv_result, variables=None):
    """
    Hausman test comparing OLS and IV estimates.

    H0: OLS is consistent (no endogeneity)
    H1: OLS is inconsistent (endogeneity present, IV needed)

    Uses the diagonal approximation

        H = sum_i (b_IV,i - b_OLS,i)^2 / (se_IV,i^2 - se_OLS,i^2)

    which ignores covariances between coefficients. Components whose
    variance difference is not strictly positive (including NaN) are dropped,
    and the degrees of freedom are reduced to match. Flooring them to a tiny
    positive number would inflate H without bound and reject H0 spuriously.

    The test assumes both estimates come from the same estimation sample.

    Parameters
    ----------
    ols_result, iv_result : dict
        Result rows containing ``'<var>_coef'`` and ``'<var>_se'`` keys.
    variables : sequence of str, optional
        Coefficients to compare. Defaults to ``['Z_1', 'Z_2', 'Z_3']``.

    Returns
    -------
    (float, float)
        The statistic and its chi-square p-value. Returns ``(nan, nan)`` when
        no component has a positive variance difference.
    """
    if variables is None:
        variables = ['Z_1', 'Z_2', 'Z_3']
    variables = list(variables)

    def _collect(res, stat):
        # Gather '<var>_<stat>' for every compared variable as a float array
        return np.array([res[f'{v}_{stat}'] for v in variables], dtype=float)

    diff = _collect(iv_result, 'coef') - _collect(ols_result, 'coef')
    v_diff = _collect(iv_result, 'se') ** 2 - _collect(ols_result, 'se') ** 2

    usable = v_diff > 0
    n_usable = int(np.count_nonzero(usable))
    if n_usable == 0:
        return np.nan, np.nan

    h_stat = float(np.sum(diff[usable] ** 2 / v_diff[usable]))
    h_pval = float(scipy_stats.chi2.sf(h_stat, n_usable))

    return h_stat, h_pval


# =====================================================================
# Main
# =====================================================================

if __name__ == '__main__':
    # Script entry point. It estimates models A-I, prints comparison tables,
    # and writes iv_results.csv, first_stage_diagnostics.csv and
    # iv_comparison.csv to OUTPUT_DIR.
    print("=" * 70)
    print("PHASE 2: IV ESTIMATION WITH LAGGED FERTILITY INSTRUMENTS")
    print("=" * 70)

    def _run_iv(*args, **kwargs):
        """Call run_iv_model and always return a ``(result, iv)`` pair.

        run_iv_model may return a bare None (e.g. too few observations).
        Unpacking that directly would raise TypeError and abort the script.
        """
        out = run_iv_model(*args, **kwargs)
        return (None, None) if out is None else out

    def _stars(p):
        """Significance stars for p-value ``p``; NaN yields ''."""
        if p < 0.001:
            return '***'
        if p < 0.01:
            return '**'
        if p < 0.05:
            return '*'
        return ''

    df = load_panel()
    print(f"\nLoaded panel: {len(df)} obs, {df['iso3'].nunique()} countries")

    all_results = []

    # ---- Model A: OLS Full Sample ----
    print("\n" + "=" * 70)
    print("MODEL A: OLS Baseline (Full Sample)")
    print("=" * 70)
    result_a = run_ols_benchmark(df, 'A_OLS_full', 'Full sample, OLS')
    if result_a:
        all_results.append(result_a)

    # ---- Model B: OLS Ex-CCA ----
    print("\n" + "=" * 70)
    print("MODEL B: OLS Baseline (Ex-CCA)")
    print("=" * 70)
    df_no_cca = df[~df['iso3'].isin(CCA_COUNTRIES)].copy()
    result_b = run_ols_benchmark(df_no_cca, 'B_OLS_exCCA', 'Ex-CCA, OLS')
    if result_b:
        all_results.append(result_b)

    # ---- Model C: 2SLS with 20yr lagged Z (Full Sample) ----
    result_c, _ = _run_iv(df, 20, 'C_IV_lag20', 'Full sample, 20yr lag')
    if result_c:
        all_results.append(result_c)

    # ---- Model D: 2SLS with 25yr lagged Z (Full Sample) ----
    result_d, _ = _run_iv(df, 25, 'D_IV_lag25', 'Full sample, 25yr lag')
    if result_d:
        all_results.append(result_d)

    # ---- Model E: 2SLS with 30yr lagged Z (Full Sample) ----
    result_e, _ = _run_iv(df, 30, 'E_IV_lag30', 'Full sample, 30yr lag')
    if result_e:
        all_results.append(result_e)

    # ---- Model F: 2SLS with 25yr lagged Z (Ex-CCA) — KEY TEST ----
    print("\n" + "*" * 70)
    print("MODEL F: KEY TEST — 2SLS Ex-CCA")
    print("  If IV rescues significance without CCA, strong causal evidence")
    print("*" * 70)
    result_f, _ = _run_iv(df_no_cca, 25, 'F_IV_lag25_exCCA',
                          'Ex-CCA, 25yr lag')
    if result_f:
        all_results.append(result_f)

    # ---- Model G: 2SLS with lagged age shares (overidentified) ----
    result_g, _ = _run_iv(df, 25, 'G_IV_shares25',
                          'Full sample, 17 lagged shares',
                          instrument_type='lagged_shares')
    if result_g:
        all_results.append(result_g)

    # ---- Model H: 2SLS with Bartik instruments ----
    result_h, _ = _run_iv(df, None, 'H_IV_bartik',
                          'Full sample, Bartik shift-share',
                          instrument_type='bartik')
    if result_h:
        all_results.append(result_h)

    # ---- Model I: 2SLS CCA only ----
    df_cca = df[df['iso3'].isin(CCA_COUNTRIES)].copy()
    result_i = None
    try:
        result_i, _ = _run_iv(df_cca, 25, 'I_IV_lag25_CCA',
                              'CCA only, 25yr lag')
    except Exception as e:  # top-level boundary: report and keep going
        print(f"\n  Model I (CCA only) failed: {e}")
        print("  (Expected — small sample may not support IV)")
    if result_i:
        all_results.append(result_i)

    # =====================================================================
    # Summary comparison
    # =====================================================================

    print("\n\n" + "=" * 70)
    print("SUMMARY: OLS vs IV COMPARISON")
    print("=" * 70)

    results_df = pd.DataFrame(all_results)

    # Print coefficient comparison table
    print(f"\n{'Model':<22} {'Method':<15} {'N':>6} {'R²':>6} "
          f"{'Z₁ coef':>9} {'Z₁ p':>7} {'Z₂ coef':>9} {'Z₂ p':>7} "
          f"{'Z₃ coef':>9} {'Z₃ p':>7}")
    print("-" * 110)

    for _, row in results_df.iterrows():
        cells = []
        for z in ('Z_1', 'Z_2', 'Z_3'):
            coef = row.get(f'{z}_coef', np.nan)
            pval = row.get(f'{z}_pval', np.nan)
            cells.append(f"{coef:>9.3f} {pval:>7.4f}{_stars(pval)}")
        print(f"{row['model']:<22} {row['method']:<15} {row['n_obs']:>6.0f} "
              f"{row['r_squared']:>6.3f} " + " ".join(cells))

    # First stage summary
    print(f"\n{'Model':<22} {'fs_Z₁_F':>10} {'fs_Z₂_F':>10} {'fs_Z₃_F':>10} "
          f"{'AR F':>8} {'AR p':>8} {'Hansen J':>8} {'J p':>8}")
    print("-" * 100)

    for _, row in results_df.iterrows():
        if row['method'].startswith('2SLS'):
            print(f"{row['model']:<22} "
                  f"{row.get('fs_Z_1_F', np.nan):>10.1f} "
                  f"{row.get('fs_Z_2_F', np.nan):>10.1f} "
                  f"{row.get('fs_Z_3_F', np.nan):>10.1f} "
                  f"{row.get('ar_f_stat', np.nan):>8.2f} "
                  f"{row.get('ar_pval', np.nan):>8.4f} "
                  f"{row.get('hansen_j', np.nan):>8.2f} "
                  f"{row.get('hansen_j_pval', np.nan):>8.4f}")

    # Hausman tests
    print("\n\nHausman Tests (H0: OLS consistent, no endogeneity):")
    print("-" * 60)
    # NOTE(review): the OLS benchmarks drop rows only on a missing outcome or
    # regressors, while the IV models also require non-missing instruments.
    # The compared estimates can therefore come from different samples, but
    # the Hausman statistic assumes a common sample.
    comparison_rows = []
    for desc, ols_res, iv_res in (
        ('Full sample (OLS vs IV-25yr)', result_a, result_d),
        ('Ex-CCA (OLS vs IV-25yr)', result_b, result_f),
    ):
        if not (ols_res and iv_res):
            continue
        h, p = hausman_test(ols_res, iv_res)
        print(f"  {desc + ':':<29} χ²={h:.4f}, p={p:.4f}")
        comparison = {
            'comparison': desc,
            'ols_model': ols_res['model'],
            'iv_model': iv_res['model'],
            'hausman_chi2': h,
            'hausman_pval': p,
        }
        for z in ('Z_1', 'Z_2', 'Z_3'):
            for est, res in (('ols', ols_res), ('iv', iv_res)):
                comparison[f'{z}_{est}_coef'] = res.get(f'{z}_coef', np.nan)
                comparison[f'{z}_{est}_se'] = res.get(f'{z}_se', np.nan)
        comparison_rows.append(comparison)

    # Save results
    results_df.to_csv(OUTPUT_DIR / "iv_results.csv", index=False)
    print(f"\n  Saved: {OUTPUT_DIR / 'iv_results.csv'}")

    # Save first-stage diagnostics
    fs_rows = []
    for _, row in results_df.iterrows():
        if row['method'].startswith('2SLS'):
            for z in ['Z_1', 'Z_2', 'Z_3']:
                fs_rows.append({
                    'model': row['model'],
                    'endogenous_var': z,
                    'f_stat': row.get(f'fs_{z}_F', np.nan),
                    'f_pval': row.get(f'fs_{z}_pval', np.nan),
                    'partial_r2': row.get(f'fs_{z}_partial_r2', np.nan),
                })
    if fs_rows:
        fs_df = pd.DataFrame(fs_rows)
        fs_df.to_csv(OUTPUT_DIR / "first_stage_diagnostics.csv", index=False)
        print(f"  Saved: {OUTPUT_DIR / 'first_stage_diagnostics.csv'}")

    # Save OLS-vs-IV comparison with Hausman statistics
    if comparison_rows:
        pd.DataFrame(comparison_rows).to_csv(
            OUTPUT_DIR / "iv_comparison.csv", index=False)
        print(f"  Saved: {OUTPUT_DIR / 'iv_comparison.csv'}")

    print("\n" + "=" * 70)
    print("PHASE 2 COMPLETE")
    print("=" * 70)
